# -*-coding:utf-8-*-
# date:2021-04-15
# author: likecy
# function : inference

import os
import argparse
import torch
import torch.nn as nn
from data_iter.datasets import letterbox
import numpy as np

import time
import datetime
import os
import math
from datetime import datetime
import cv2
import torch.nn.functional as F
import xml.etree.cElementTree as ET
from models.build_model import *
from utils.model_utils import *
from sklearn.metrics import precision_score, recall_score, f1_score
import sys
# from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from utils.common_utils import *

# Maps an architecture name to its model constructor. The same name is used as
# the checkpoint sub-folder name scanned in ``__main__`` and is passed on to
# ``eval_models_pth_test_image_list``. The constructors are not defined in this
# file; presumably they come from ``from models.build_model import *`` (confirm).
# The EfficientNet variants share one constructor and are told which variant to
# build through the ``model_name=`` keyword at call time.
MODEL_NAMES = {
    'alexnet': Alexnet,
    'googlenet': Googlenet,
    'resnet18': Resnet18,
    'resnet50': Resnet50,
    'resnet101': Resnet101,
    'resnet152': Resnet152,
    'resnext101_32x8d': Resnext101_32x8d,
    'resnext101_32x16d': Resnext101_32x16d,
    'resnext101_32x48d': Resnext101_32x48d,
    'resnext101_32x32d': Resnext101_32x32d,
    'densenet121': Densenet121,
    'densenet169': Densenet169,
    # 'moblienetv2' is a historical misspelling. It is kept so existing
    # checkpoint folders with that name still resolve; 'mobilenetv2' is the
    # correctly spelled alias for the same constructor.
    'moblienetv2': Mobilenetv2,
    'mobilenetv2': Mobilenetv2,
    'efficientnet-b7': Efficientnet,
    'efficientnet-b0': Efficientnet,
    'efficientnet-b8': Efficientnet,
    'squeezenet1_0': Squeezenet1_0,
    'squeezenet1_1': Squeezenet1_1,
    'shufflenet_v2_x0_5': Shufflenet_v2_x0_5,
    'shufflenet_v2_x1_0': Shufflenet_v2_x1_0,
    'shufflenet_v2_x1_5': Shufflenet_v2_x1_5,
    'shufflenet_v2_x2_0': Shufflenet_v2_x2_0,
}

# Root folder of the test images, with one sub-folder per class. It can be
# overridden with the CT_TEST_PATH environment variable; the original
# developer's path is the default. os.path.join(..., '') guarantees a trailing
# separator, because callers may build paths by plain string concatenation.
All_used_test_path = os.path.join(
    os.environ.get('CT_TEST_PATH', '/Users/chenyun/xjdlab/CT数据识别/tests/'), '')


def _sorted_class_dirs(root):
    """Return the class folders under *root*, sorted by their integer prefix.

    Folder names are expected to look like ``<int>-<label>``. A folder's
    position in the returned list is used as its ground-truth class index.
    Hidden entries and plain files (e.g. macOS ``.DS_Store``) are skipped so
    they cannot break the ``int()`` sort key or the later ``os.listdir``.
    """
    names = [
        name for name in os.listdir(root)
        if not name.startswith('.') and os.path.isdir(os.path.join(root, name))
    ]
    return sorted(names, key=lambda name: int(name.split('-')[0]))


def _preprocess_image(img, img_size):
    """Convert an array returned by ``cv2.imread`` into a batched float32 tensor.

    The image is resized to ``img_size`` (given as ``[H, W]``) with bicubic
    interpolation and normalized as ``(pixel - 128) / 256``. The result has
    shape ``(1, C, H, W)``; C is 3 for the default ``cv2.imread`` color mode.
    """
    resized = cv2.resize(
        img, (img_size[1], img_size[0]), interpolation=cv2.INTER_CUBIC)
    normed = (resized.astype(np.float32) - 128.) / 256.
    # HWC -> CHW, then add the batch dimension.
    return torch.from_numpy(normed.transpose(2, 0, 1)).unsqueeze(0)


def eval_models_pth_test_image_list(used_model_name, used_test_model, epoch):
    """Evaluate one checkpoint on the test image folder and report P / R / F1.

    Args:
        used_model_name: key into ``MODEL_NAMES`` that selects the architecture
            (a 2-class head is built).
        used_test_model: path of the ``state_dict`` checkpoint to load.
        epoch: epoch tag, used only in the report file name.

    Every ``.jpg`` file inside each class folder of ``All_used_test_path`` is
    classified. The ground-truth label is the folder's index after sorting the
    folders by their integer prefix. Metrics are printed and written to
    ``./log/<model>_<epoch>_interface_result.txt``. Nothing is written when no
    test image was found.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    used_test_path = All_used_test_path  # test image root folder
    img_size = [224, 224]
    if used_model_name.startswith('efficientnet'):
        # The EfficientNet constructor serves several variants and must be
        # told which one to build.
        model_ = MODEL_NAMES[used_model_name](
            model_name=used_model_name, num_classes=2)
    else:
        model_ = MODEL_NAMES[used_model_name](num_classes=2)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    model_ = model_.to(device)

    chkpt = torch.load(used_test_model, map_location=device)
    model_.load_state_dict(chkpt)
    print('load test model : {}'.format(used_test_model))

    model_.eval()  # switch to inference mode

    # ---------------------------------------------------------------- class folders
    class_dirs = _sorted_class_dirs(used_test_path)
    dict_static = dict(enumerate(class_dirs))  # label index -> folder name
    dict_r = dict.fromkeys(class_dirs, 0)      # folder -> correct predictions
    dict_p = dict.fromkeys(class_dirs, 0)      # folder -> times predicted
    print(dict_static)

    # ---------------------------------------------------------------- predict images
    y_true = []
    y_pred = []
    with torch.no_grad():
        for gt_label, doc in enumerate(class_dirs):
            class_path = os.path.join(used_test_path, doc)
            for file in os.listdir(class_path):
                if not file.lower().endswith('.jpg'):
                    continue
                print('------>>> {} - gt_label : {}'.format(file, gt_label))

                img_path = os.path.join(class_path, file)
                img = cv2.imread(img_path)
                if img is None:
                    # cv2.imread returns None instead of raising on unreadable files.
                    print('skip unreadable image : {}'.format(img_path))
                    continue

                img_ = _preprocess_image(img, img_size).to(device)
                outputs = F.softmax(model_(img_), dim=1)[0]
                output = outputs.cpu().numpy()

                max_index = int(np.argmax(output))
                score_ = output[max_index]
                y_true.append(gt_label)
                y_pred.append(max_index)
                print('gt {} - {} -- pre {}     --->>>    confidence {}'.format(
                    doc, gt_label, max_index, score_))
                # Guard against a predicted index with no matching folder
                # (e.g. fewer class folders than model outputs).
                pred_doc = dict_static.get(max_index)
                if pred_doc is not None:
                    dict_p[pred_doc] += 1
                if gt_label == max_index:
                    dict_r[doc] += 1

    print('\n-----------------------------------------------\n')
    if not y_true:
        print('no test images found under {}'.format(used_test_path))
        return

    # Top-1 per-class counts.
    print('correct per class : {}'.format(dict_r))
    print('predicted per class : {}'.format(dict_p))
    # sklearn's default average='binary' with pos_label=1 is used.
    p = precision_score(y_true, y_pred)
    r = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    print("精确度：", p)
    print("召回率:", r)
    print("F1值", f1)

    log_dir = './log'
    os.makedirs(log_dir, exist_ok=True)
    report_path = os.path.join(
        log_dir, '{}_{}_interface_result.txt'.format(used_model_name, epoch))
    with open(report_path, "w", encoding='utf-8') as fs:
        fs.write('\n-----------------------------------------------\n')
        fs.write("eval form :" + used_test_model)
        fs.write('\n-----------------------------------------------\n')
        fs.write("\n 精确度:{}".format(p))
        fs.write("\n 召回率:{}".format(r))
        fs.write("\n F1_source值:{}".format(f1))


if __name__ == "__main__":
    # Tee stdout into a timestamped log file under ./log. Logger is not defined
    # in this file; presumably it comes from a utils star import (confirm).
    os.makedirs('./log', exist_ok=True)
    # Use time.localtime() directly. `loc_time` is not defined in this module
    # and only worked if a star import happened to export it.
    f_log = Logger(
        filename='./log/test_eval_{}.log'.format(
            time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())))
    print("开始记录日志～")
    sys.stdout = f_log
    # Expected layout: <traind_pth>/<model name>/<name>-<epoch>.<ext>
    traind_pth = "/Users/chenyun/xjdlab/classification/test-models/"
    for model_name in sorted(os.listdir(traind_pth)):
        model_dir = os.path.join(traind_pth, model_name)
        # Skip hidden entries such as .DS_Store and any stray plain files.
        if model_name.startswith('.') or not os.path.isdir(model_dir):
            continue
        if model_name not in MODEL_NAMES:
            print('skip unknown model folder : {}'.format(model_name))
            continue
        print(model_name)
        for test_pth in sorted(os.listdir(model_dir)):
            user_test_pth = os.path.join(model_dir, test_pth)
            if test_pth.startswith('.') or not os.path.isfile(user_test_pth):
                continue
            # Epoch = last '-'-separated field of the name before its first '.'.
            file_name = test_pth.split('.')[0]
            epoch = file_name.split('-')[-1]
            print(epoch)
            print("epoch: {} model name:{},pth: {}".format(
                epoch, model_name, user_test_pth))
            eval_models_pth_test_image_list(
                model_name, user_test_pth, epoch)
